#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

import torch
import numpy as np

###################################################
# round with gradient - assumed to have unit gradient.
# note: this rounds towards even (bankers round / unbiased round - like numpy / pytorch).
class RoundG(torch.autograd.Function):
    """Round-to-even with a straight-through (unit) gradient.

    An optional uniform dither in [-dither, +dither] is added before
    rounding to make the rounding stochastic/unbiased.
    """
    @staticmethod
    def forward(ctx, x, dither=0.0):
        # dither==0.0 is falsy, so no noise tensor is created in that case
        rand_val = torch.zeros_like(x).uniform_(-dither, dither) if dither else 0.0
        y = torch.round(x + rand_val)
        return y

    @staticmethod
    def backward(ctx, dy):
        # straight-through estimator for x; no gradient w.r.t. dither.
        # bug fix: backward must return one gradient per forward() input -
        # a lone 'dy' raises "expected 2 gradients, got 1" whenever dither
        # is passed explicitly to apply(). The trailing None is ignored by
        # autograd when forward() was called without the optional argument.
        return dy, None

    @staticmethod
    def symbolic(g, x, dither=0.0):
        # onnx export hook: emit a custom op; the dither value is not exported
        return g.op("RoundG", x)


# round with gradient - assumed to have unit gradient.
# rounds half away from zero: towards -infinity for negative values and
# +infinity for non-negative ones. reasonably unbiased, though not as good
# as round-to-even. note: differs from pytorch/numpy round (half to even).
class RoundSymG(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # add a sign-matched half, then truncate towards zero.
        # (the original dither hook was a constant 0.0, i.e. dead code.)
        half = (x >= 0).float() - (x < 0).float()
        y = (x + 0.5 * half).int().float()
        return y

    @staticmethod
    def backward(ctx, dy):
        # straight-through estimator: gradient passes unchanged
        return dy

    @staticmethod
    def symbolic(g, x):
        # onnx export hook: emit a custom op
        return g.op("RoundSymG", x)

# round with gradient - assumed to have unit gradient.
# rounds halves towards +infinity, like typical fixed point hardware.
class RoundUpG(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # classic add-half-and-floor rounding.
        # (the original dither hook was a constant 0.0, i.e. dead code.)
        return torch.floor(x + 0.5)

    @staticmethod
    def backward(ctx, dy):
        # straight-through estimator: gradient passes unchanged
        return dy

    @staticmethod
    def symbolic(g, x):
        # onnx export hook: emit a custom op
        return g.op("RoundUpG", x)


# round with gradient - assumed to have unit gradient.
# snaps each value to the nearest power of two (rounding in log2 space).
class Round2G(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # round the base-2 exponent, then exponentiate back.
        # NOTE(review): assumes x > 0 - log2 of non-positive input is nan/-inf.
        exponent = torch.round(torch.log2(x))
        return 2.0 ** exponent

    @staticmethod
    def backward(ctx, dy):
        # straight-through estimator: gradient passes unchanged
        return dy

    @staticmethod
    def symbolic(g, x):
        # onnx export hook: emit a custom op
        return g.op("Round2G", x)


# ceil with a straight-through (unit) gradient.
class CeilG(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.ceil()

    @staticmethod
    def backward(ctx, dy):
        # straight-through estimator: gradient passes unchanged
        return dy

    @staticmethod
    def symbolic(g, x):
        # onnx export hook: emit a custom op
        return g.op("CeilG", x)


# ceil to the next power of 2 - with unit gradient
class Ceil2G(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # ceil the base-2 exponent, then exponentiate back.
        # NOTE(review): assumes x > 0 - log2 of non-positive input is nan/-inf.
        exponent = x.log2().ceil()
        return 2.0 ** exponent

    @staticmethod
    def backward(ctx, dy):
        # straight-through estimator: gradient passes unchanged
        return dy

    @staticmethod
    def symbolic(g, x):
        # onnx export hook: emit a custom op
        return g.op("Ceil2G", x)


# floor to the power of 2 below - with unit gradient
class Floor2G(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # floor the base-2 exponent, then exponentiate back.
        # NOTE(review): assumes x > 0 - log2 of non-positive input is nan/-inf.
        exponent = x.log2().floor()
        return 2.0 ** exponent

    @staticmethod
    def backward(ctx, dy):
        # straight-through estimator: gradient passes unchanged
        return dy

    @staticmethod
    def symbolic(g, x):
        # onnx export hook: emit a custom op
        return g.op("Floor2G", x)


class QuantizeDequantizeG(torch.autograd.Function):
    """Fake (simulated) quantization with custom gradients.

    Forward computes clamp(round(x*scale), width_min, width_max) / scale.
    Backward passes the gradient of x straight through wherever the rounded
    value landed inside [width_min, width_max], and estimates the gradient
    of the scale numerically (the scale may be constrained to powers of 2,
    where an analytical gradient is not available).
    """

    @staticmethod
    def forward(ctx, x, scale_tensor, width_min, width_max, power2, axis, round_type='round_up'):
        """x: input tensor. scale_tensor: quantization scale (x is multiplied
        by it). width_min/width_max: clamp limits in the scaled (integer)
        domain. power2: snap the scale down to a power of 2. axis: unused
        here (kept for interface compatibility with TorchQuantizeDequantizeG).
        round_type: 'round_up', 'round_sym' or anything else for torch.round."""
        # apply quantization
        y, x_scaled_round = QuantizeDequantizeG.quant_dequant(ctx, x, scale_tensor, width_min, width_max, power2, round_type)
        # save for backward
        ctx.save_for_backward(x, scale_tensor, x_scaled_round)
        ctx.width_min, ctx.width_max, ctx.power2, ctx.round_type = width_min, width_max, power2, round_type
        return y


    @staticmethod
    def quant_dequant(ctx, x, scale_tensor, width_min, width_max, power2, round_type='round_up'):
        """Scale, round, clamp and rescale x.
        Returns (dequantized y, rounded scaled x)."""
        # clip values need ceil2 and scale values need floor2
        scale_tensor = Floor2G.apply(scale_tensor) if power2 else scale_tensor
        x_scaled = (x * scale_tensor)

        # round
        if round_type == 'round_up':    # typically for activations
            rand_val = 0.5
            x_scaled_round = torch.floor(x_scaled+rand_val)
        elif round_type == 'round_sym': # typically for weights
            # NOTE(review): the +/-0.5 offset is derived from the sign of x,
            # not x_scaled - equivalent only while scale_tensor > 0; confirm.
            rand_val = (-0.5) * (x < 0).float() + (+0.5) * (x >= 0).float()
            x_scaled_round = (x_scaled+rand_val).int().float()
        else:
            x_scaled_round = torch.round(x_scaled)
        #
        # invert the scale
        scale_inv = scale_tensor.pow(-1.0)
        # clamp to the quantized range, then de-quantize
        x_clamp = torch.clamp(x_scaled_round, width_min, width_max)
        y = x_clamp * scale_inv
        return y, x_scaled_round


    @staticmethod
    def backward(ctx, dy):
        # use numerical gradient as analytical expression is difficult.
        # this includes gradient of scale, round and clip
        x, scale_tensor, x_scaled_round = ctx.saved_tensors
        width_min, width_max, power2, round_type = ctx.width_min, ctx.width_max, ctx.power2, ctx.round_type

        # gradient w.r.t. x: straight-through inside the clamp range, zero outside
        x_in_range = (x_scaled_round>=width_min).float()*(x_scaled_round<=width_max).float()
        dx = dy * x_in_range

        # gradient w.r.t. s - numerical difference, since scale_tensor can be
        # discrete (power of 2): evaluate at scale/2, scale and 2*scale and
        # average the slopes on either side of the current scale.
        scale_dl = (scale_tensor / 2)
        scale_l = (scale_tensor - scale_dl)
        scale_dh = (scale_tensor)
        scale_h = (scale_tensor + scale_dh)
        y_l, _ = QuantizeDequantizeG.quant_dequant(ctx, x, scale_l, width_min, width_max, power2, round_type)
        y_m, _ = QuantizeDequantizeG.quant_dequant(ctx, x, scale_tensor, width_min, width_max, power2, round_type)
        y_h, _ = QuantizeDequantizeG.quant_dequant(ctx, x, scale_h, width_min, width_max, power2, round_type)

        ds_l = (y_m - y_l) / scale_dl
        ds_h = (y_h - y_m) / scale_dh
        ds_local = (ds_l + ds_h) / 2
        # NOTE(review): ds has the shape of x; if scale_tensor requires grad
        # and has a different (scalar/per-channel) shape, autograd expects a
        # matching gradient shape - verify against callers.
        ds = dy * ds_local

        # one gradient (or None) per forward() input:
        # x, scale_tensor, width_min, width_max, power2, axis, round_type
        return dx, ds, None, None, None, None, None


    @staticmethod
    def symbolic(g, x, scale_tensor, width_min, width_max, power2, axis, round_type='round_up'):
        # bug fix: 'axis' was missing from this signature - symbolic() is
        # invoked with the same argument list as forward() during onnx
        # export, so the mismatch shifted arguments and broke the export.
        return g.op("QuantizeDequantizeG",  x, scale_tensor)


class TorchQuantizeDequantizeG(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale_tensor, width_min, width_max, power2, axis, round_type='round_up'):
        # apply quantization
        if scale_tensor.dim()>0:
            device = x.device
            axis_size = int(x.size(axis))
            scale_tensor = scale_tensor.reshape(axis_size)
            zero_point = torch.zeros(axis_size).to(device=device, dtype=torch.long)
            y = torch.fake_quantize_per_channel_affine(x, scale=scale_tensor, zero_point=zero_point, axis=axis,
                    quant_min=int(width_min), quant_max=int(width_max))
        else:
            y = torch.fake_quantize_per_tensor_affine(x, scale=float(scale_tensor), zero_point=0,
                    quant_min=int(width_min), quant_max=int(width_max))
        #
        return y


    @staticmethod
    def symbolic(g,  x, scale_tensor, width_min, width_max, power2, round_type='round_up'):
        return g.op("TorchQuantizeDequantizeG",  x, scale_tensor)